<config lang="json">
    {
      "name": "tools_Jamin_Lebedeff",
      "type": "native-python",
      "version": "0.0.2",
      "api_version": "0.1.19",
      "description": "Reconstruct an RGB image acquired with a cellphone camera on a Jamin Lebedeff Microscope.",
      "tags": [],
      "ui": [],
      "inputs": null,
      "outputs": null,
      "flags": [],
      "icon": null,
      "env": null,
      "runnable": false,
      "requirements": ["numpy", "tensorflow", "matplotlib", "h5py", "pillow", "scikit-image"],
      "dependencies": []
    }
    </config>
    
    
    
    <script lang="python">
    import numpy as np
    import base64
    import asyncio
    import os
    import time
    from skimage import io
    from PIL import Image
    from io import BytesIO
    
    import tensorflow as tf
    import matplotlib.pyplot as plt
    import matplotlib as mpl
    import h5py
    from mpl_toolkits.mplot3d.axes3d import Axes3D
    import NanoImagingPack as nip
    # Set the NanoImagingPack default image viewer entry to 'NIP_VIEW'.
    nip.config.__DEFAULTS__['IMG_VIEWER']='NIP_VIEW'
    #%load_ext autoreload
    #%autoreload 2
    # Matplotlib defaults: 8x4 figures and a gray colormap for images.
    mpl.rc('figure',  figsize=(8, 4))
    mpl.rc('image', cmap='gray')
    
    # Select the Agg backend and switch interactive mode off, so figures are
    # only rendered when explicitly saved.
    # NOTE(review): matplotlib.pyplot was already imported further up; on older
    # matplotlib releases a backend chosen after that import may not take
    # effect - confirm against the deployed version.
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    plt.ioff()
    
    class ImJoyPlugin():
        """ImJoy plugin that estimates an optical-path-difference (OPD) map from an RGB image.

        The per-pixel RGB values are first matched against a calibrated
        OPD -> (R, G, B) look-up table (``findOPD``); the resulting map is then
        refined by gradient descent on an L2 data term plus a total-variation
        (TV) penalty (``Reg_TV``). The result is rendered to a PNG file and
        shown in an ImJoy image window.

        ``api`` is not defined in this file; it is presumably injected by the
        ImJoy runtime.
        """

        def setup(self):
            """Initialise plugin state: no result window has been created yet."""
            self.window = None

        async def run(self, my):
            """Announce the worker and print the configuration received from the GUI.

            No reconstruction is started here; the TV weight is currently fixed
            by ``process_rgb`` (``reg_tv=1.``).
            """
            api.showStatus('[UC2] Python worker to reconstruct a hologram')

            # parameters coming from the GUI
            config = my.config
            print(config)

        async def process_rgb(self, base64_url):
            ''' Bridging javascript with Python

            Decode a base64-encoded image and run the reconstruction on it.

            Parameters
            ----------
            base64_url : str
                Either a data URL (``data:image/...;base64,<payload>``) or the
                bare base64 payload.
            '''
            await api.showMessage('Starting the Processing...')

            # Strip the data-URL header if there is one; a bare payload is
            # passed through unchanged instead of raising IndexError.
            data = base64_url.split('base64,', 1)[-1]

            # decode base64 to bytes
            img_bytes = base64.b64decode(data)
            buffer = BytesIO(img_bytes)

            # Open with PIL and force three colour channels: grayscale or
            # palette images would otherwise yield a 2-D array that jlrecon
            # cannot index as [row, col, channel].
            img = Image.open(buffer).convert('RGB')
            arr = np.array(img)
            print('Mysize: ')
            print(arr.shape)

            # process the image
            await self.jlrecon(arr, reg_tv=1.)

        @staticmethod
        def _center_crop(image, roi_size):
            """Return the central ``roi_size`` x ``roi_size`` window of ``image``.

            Only the first two axes are cropped. An axis shorter than
            ``roi_size`` is returned in full: the start index is clamped to 0
            so that it can never become negative and wrap around.
            """
            half = roi_size // 2
            centre_row = image.shape[0] // 2
            centre_col = image.shape[1] // 2
            row_start = max(0, centre_row - half)
            col_start = max(0, centre_col - half)
            return image[row_start:centre_row + half, col_start:centre_col + half]

        async def jlrecon(self, myjl, reg_tv):
            ''' Reconstruct an OPD map from an RGB image and display it.

            Parameters
            ----------
            myjl : numpy array indexed as [row, column, channel]
                Channels 0, 1, 2 are used as R, G, B. If its mean is exactly 0
                the image file configured below is loaded instead.
            reg_tv : float
                Weight of the total-variation term in the objective.

            Side effects
            ------------
            Writes 'MY_JL_result.png' to the working directory, reports
            progress through ``api.showMessage`` and creates or updates the
            ImJoy window stored in ``self.window``.
            '''
            # NOTE(review): machine-specific absolute path. The calibration
            # file (and the fallback image) must exist at this location; a
            # repository-relative data set under './data/JL/' appears to have
            # been used previously - confirm which one should ship.
            mydatafolder = '/Users/bene/Dropbox/Dokumente/Promotion/PROJECTS/Jamin.Lebedeff/PYTHON/TF_JL_TVMin/data/HUAWEI/2019_05_21-Pont/'
            myimagefile = 'IMG_20190520_144336.jpg'
            mymatfile = 'JL_tensorflow.mat' # MATLAB v7.3 (HDF5) file with the OPD/R/G/B maps

            myroisize = 512 # a centred sub-region is cropped to reduce the computational time

            # Fitting-related
            n_poly = 11 # order of the polynomial fitted to the calibration data
            opdmax = 1650 # maximum optical path-difference in [nm]
            use_matlab = False # use the image/OPD/mask stored in the MAT file instead of myjl?

            # Optimization-related
            lambda_tv = reg_tv # TV weight
            epsC = 1e-2 # lower clamp used inside the TV regularizer
            lr = 100 # learning-rate
            Niter = 200 # number of iterations
            Ndisplay_text = 10 # report the losses every N iterations

            # GUI output
            await api.showMessage('The TV-regularizer lambda is: '+str(lambda_tv))

            '''Load image and crop it'''
            if np.mean(myjl) == 0:
                myimage = plt.imread(mydatafolder + myimagefile)
            else:
                myimage = myjl
                print(myjl.shape)

            myimage = np.float32(self._center_crop(myimage, myroisize))

            ''' Load the calibration data '''
            # Everything is copied into memory inside the with-block so the
            # HDF5 file handle is released immediately.
            with h5py.File(mydatafolder+mymatfile, 'r') as mat_matlab_data:
                B_map = np.squeeze(np.array(mat_matlab_data['B_mat']))
                R_map = np.squeeze(np.array(mat_matlab_data['R_mat']))
                G_map = np.squeeze(np.array(mat_matlab_data['G_mat']))
                OPD_map = np.squeeze(np.array(mat_matlab_data['OPD_mat']))
                if use_matlab:
                    OPD_mask = np.squeeze(np.array(mat_matlab_data['mask_mat']))
                    myopd_res_matlab = np.squeeze(np.array(mat_matlab_data['OPDMap_mat']))
                    I_exp = np.array(mat_matlab_data['I_ref_mat']) # value we want to reconstruct

            nopdsteps = OPD_map.shape[0] # number of quantized OPD-RGB look-up values
            mulfac = opdmax/nopdsteps # OPD step per look-up index

            if use_matlab:
                # float32 to match the dtype of the measured channels below
                myinitopd = np.float32(mulfac*myopd_res_matlab)
            else:
                I_exp = np.transpose(myimage, (2,0,1)) # channel-first layout
                myinitopd = findOPD(I_exp,R_map,G_map,B_map,mulfac)
                OPD_mask = np.ones(myinitopd.shape) # we don't want a mask here

            ''' Fit a polynomial to each colour channel to remove possible noise '''
            opd_axis = np.squeeze(OPD_map)
            R_coeffs = np.squeeze(np.poly1d(np.polyfit(opd_axis, R_map, n_poly)).coeffs)
            G_coeffs = np.squeeze(np.poly1d(np.polyfit(opd_axis, G_map, n_poly)).coeffs)
            B_coeffs = np.squeeze(np.poly1d(np.polyfit(opd_axis, B_map, n_poly)).coeffs)

            # separate the measurement into its colour channels
            R_exp = np.squeeze(np.float32(I_exp[0,:,:]))
            G_exp = np.squeeze(np.float32(I_exp[1,:,:]))
            B_exp = np.squeeze(np.float32(I_exp[2,:,:]))

            ''' TENSORFLOW STARTS HERE '''
            # Objective:
            #   mean((R(opd)-R_exp)^2 + (G(opd)-G_exp)^2 + (B(opd)-B_exp)^2)
            #   + lambda_tv * TV(opd)
            # A private graph is used so repeated calls neither grow the global
            # default graph nor re-initialise variables of earlier calls. The
            # graph is built without any ``await`` inside the context manager.
            graph = tf.Graph()
            with graph.as_default():
                TF_R_exp = tf.constant(R_exp)
                TF_G_exp = tf.constant(G_exp)
                TF_B_exp = tf.constant(B_exp)

                # Hyper-parameters are fed at run time
                TF_lr = tf.placeholder(tf.float32, shape=[])
                TF_lambda_TV = tf.placeholder(tf.float32, shape=[])
                TF_epsC = tf.placeholder(tf.float32, shape=[])

                # The variable being optimised, initialised with the look-up result
                TF_opd = tf.Variable(myinitopd)

                # Only pixels inside the mask (OPD_mask > 0) contribute; all
                # other pixels of the masked map are zero.
                updates = tf.boolean_mask(TF_opd, OPD_mask>0)
                indexes = tf.cast(tf.where(OPD_mask > 0), tf.int32)
                TF_opd_masked = tf.scatter_nd(indexes, updates, tf.shape(OPD_mask))

                # Colour predicted by the fitted polynomials for the current OPD
                TF_R_guess = polyeval(TF_opd_masked, R_coeffs)
                TF_G_guess = polyeval(TF_opd_masked, G_coeffs)
                TF_B_guess = polyeval(TF_opd_masked, B_coeffs)

                # Data term: L2 distance between predicted and measured RGB
                TF_mySqrError = tf.reduce_mean(((TF_R_guess-TF_R_exp)**2 +
                                            (TF_G_guess-TF_G_exp)**2 +
                                            (TF_B_guess-TF_B_exp)**2))

                # Smoothness term: a small TV norm avoids discontinuities
                TF_myTVError = TF_lambda_TV * Reg_TV(TF_opd_masked, epsC=TF_epsC)

                # Combined loss. Range penalties (Reg_NegSqr / Reg_PosSqr) are
                # not part of the objective.
                TF_myError = TF_mySqrError + TF_myTVError

                # The combined loss gets its own optimizer instance
                TF_opt = tf.train.AdamOptimizer(learning_rate = TF_lr)
                TF_loss = TF_opt.minimize(TF_myError)

                TF_init = tf.global_variables_initializer()

            ''' start optimization part here '''
            # Closed explicitly (rather than via ``with``) because the loop
            # awaits; the session must not stay installed as a default while
            # other coroutines run.
            sess = tf.Session(graph=graph)
            try:
                sess.run(TF_init)

                for i in range(Niter):
                    my_loss_tv,my_loss_l2,_ = sess.run(
                        [TF_myTVError,TF_mySqrError,TF_loss],
                        feed_dict={TF_lr:lr, TF_lambda_TV:lambda_tv, TF_epsC:epsC})

                    if i % Ndisplay_text == 0:
                        mystring = "My Loss L2: @iter: "+str(i)+" is: "+str(my_loss_l2)+", My Loss TV: "+str(my_loss_tv)
                        await api.showMessage(mystring)
                        print(mystring)

                myopd_new = sess.run(TF_opd_masked)
            finally:
                sess.close()

            # Plot the result on a dedicated figure and save it as png; the
            # figure is closed afterwards so repeated calls do not accumulate.
            name_plot = 'MY_JL_result.png'
            fig, ax = plt.subplots()
            try:
                ax.imshow(myopd_new,cmap='gray')
                fig.savefig(name_plot,dpi=300)
            finally:
                plt.close(fig)

            with open(name_plot, 'rb') as f:
                result = base64.b64encode(f.read()).decode('ascii')
            imgurl = 'data:image/png;base64,' + result
            data = {"src": imgurl}
            data_plot = {
                'name':'Plot charts: show png',
                'type':'imjoy/image',
                'w':12, 'h':15,
                'data':data}

            # Check if window was defined
            if self.window is None:
                self.window = await api.createWindow(data_plot)
                print('Window created')
            else:
                print('Update window.')
                try:
                    await self.window.run(data=data)
                except Exception:
                    # The old window may have been closed by the user; fall
                    # back to a fresh one.
                    self.window = await api.createWindow(data_plot)
                    print('Could not print to old window. New window created.')

            await api.showMessage('Done!')
    
    # Hand a plugin instance to ImJoy. `api` is not defined in this file; it is
    # presumably injected by the ImJoy runtime. The helper functions below are
    # only looked up when the plugin methods are actually called.
    api.export(ImJoyPlugin())
    
    
    
    def polyeval(x,coeff):
        """Evaluate a polynomial at ``x``.

        Parameters
        ----------
        x : scalar, numpy array or TensorFlow tensor
            Point(s) at which the polynomial is evaluated; only ``*`` and ``+``
            are applied to it.
        coeff : sequence of numbers
            Coefficients ordered from the highest power down to the constant
            term (the ordering of ``np.poly1d(...).coeffs``).

        Returns
        -------
        ``coeff[0]*x**n + ... + coeff[n]``; ``0`` for an empty ``coeff``.
        """
        # Horner's scheme: no explicit powers are formed, which avoids the
        # large intermediate x**n values of a term-by-term evaluation.
        y = 0
        for c in coeff:
            y = y * x + c
        return y
    
    def tf_abssqr(input):
        """Return the real part of ``input`` times its complex conjugate."""
        conjugate = tf.conj(input)
        product = input*conjugate
        return tf.real(product)
    
    def findOPD(RGBImg,R,G,B,OPDMax):
        """Nearest-colour look-up of the OPD for every pixel.

        For each pixel the look-up entry with the smallest squared RGB
        distance is selected and its index is scaled by ``OPDMax``.

        Parameters
        ----------
        RGBImg : array indexed as [channel, row, column]
            Channels 0, 1, 2 are compared against ``R``, ``G``, ``B``.
        R, G, B : 1-D arrays of equal length
            Colour values of the look-up table, one entry per OPD step.
        OPDMax : number
            Factor the winning index is multiplied with. NOTE: despite its
            name this acts as a per-index scale, not as an upper bound.

        Returns
        -------
        float32 array with one value per pixel.
        """
        # Two leading singleton axes let the tables broadcast against an
        # image plane that carries a trailing singleton axis.
        R = np.expand_dims(np.expand_dims(R,0),0)
        G = np.expand_dims(np.expand_dims(G,0),0)
        B = np.expand_dims(np.expand_dims(B,0),0)

        # Squared distance of every pixel to every table entry. Broadcasting
        # gives the same values as tiling the image along the last axis,
        # without materialising the tiled copies.
        myErr = (np.expand_dims(RGBImg[0,:,:],-1) - R)**2
        myErr = myErr + (np.expand_dims(RGBImg[1,:,:],-1) - G)**2
        myErr = myErr + (np.expand_dims(RGBImg[2,:,:],-1) - B)**2

        # Index of the closest colour along the table axis
        OPDMap = np.argmin(myErr, axis=-1)
        return np.float32(OPDMap*OPDMax)
            
        # https://github.com/tensorflow/tensorflow/issues/18383
    
    
    def Reg_TV(toRegularize, BetaVals = (1,1), epsR = 1, epsC=1e-10, is_circ = True):
        """Total-variation style regularizer built from one-pixel differences.

        The Regularisation modification with epsR was introduced, according to
        Ferreol Soulez et al. "Blind deconvolution of 3D data in wide field
        fluorescence microscopy".

        Parameters
        ----------
        toRegularize : tensor
            The map to regularize. Differences are taken along axes 0 and 1.
        BetaVals : sequence of two numbers
            Divisors applied to the differences along axis 0 and axis 1.
        epsR : number
            Added under the square root.
        epsC : number or scalar tensor
            Lower clamp applied to the summed square roots.
        is_circ : bool
            If true, neighbours are obtained with a circular shift.

        Returns
        -------
        Scalar tensor: the mean over all elements of
        ``sqrt(|dL0|^2 + |dL1|^2 + epsR) + sqrt(|dR0|^2 + |dR1|^2 + epsR)``,
        where ``dL*`` / ``dR*`` are the differences to the two opposite
        neighbours along each axis.
        """
        if is_circ:
            # difference to the neighbour one step in each direction, wrapping
            # around at the borders
            aGradL_1 = (toRegularize - tf.manip.roll(toRegularize, 1, 0))/BetaVals[0]
            aGradL_2 = (toRegularize - tf.manip.roll(toRegularize, 1, 1))/BetaVals[1]

            aGradR_1 = (toRegularize - tf.manip.roll(toRegularize, -1, 0))/BetaVals[0]
            aGradR_2 = (toRegularize - tf.manip.roll(toRegularize, -1, 1))/BetaVals[1]

            print('We use circular shift for the TV regularizer')
        else:
            # Interior-only differences, no wrap-around.
            # NOTE(review): this branch indexes three axes, so it only works
            # for rank-3 input, and the 1:-2 windows leave out the last two
            # samples of every axis - confirm whether 1:-1 was intended.
            toRegularize_sub = toRegularize[1:-2,1:-2,1:-2]
            aGradL_1 = (toRegularize_sub - toRegularize[2:-1,1:-2,1:-2])/BetaVals[0] # neighbour at +1 along axis 0
            aGradL_2 = (toRegularize_sub - toRegularize[1:-2,2:-1,1:-2])/BetaVals[1] # neighbour at +1 along axis 1

            aGradR_1 = (toRegularize_sub - toRegularize[0:-3,1:-2,1:-2])/BetaVals[0] # neighbour at -1 along axis 0
            aGradR_2 = (toRegularize_sub - toRegularize[1:-2,0:-3,1:-2])/BetaVals[1] # neighbour at -1 along axis 1

        mySqrtL = tf.sqrt(tf_abssqr(aGradL_1)+tf_abssqr(aGradL_2)+epsR)
        mySqrtR = tf.sqrt(tf_abssqr(aGradR_1)+tf_abssqr(aGradR_2)+epsR)

        mySqrt = mySqrtL + mySqrtR

        # Clamp from below at epsC to avoid divisions by zero
        lower_bound = epsC*tf.ones_like(mySqrt)
        mySqrt = tf.where(tf.less(mySqrt, lower_bound), lower_bound, mySqrt)

        myReg = tf.reduce_mean(mySqrt)

        return myReg
    
    
    def Reg_NegSqr(toRegularize):
        """Penalty on negative values.

        Returns the mean over all elements of ``tf_abssqr(x)`` where ``x < 0``
        and of 0 everywhere else.
        """
        zeros = tf.zeros_like(toRegularize)
        is_negative = tf.less(toRegularize, zeros)
        # squared magnitude where negative, zero otherwise
        penalty = tf.where(is_negative, tf_abssqr(toRegularize), zeros)
        return tf.reduce_mean(penalty)
     
    def Reg_PosSqr(toRegularize):
        """Penalty on positive values.

        Returns the mean over all elements of ``tf_abssqr(x)`` where ``x > 0``
        and of 0 everywhere else.
        """
        zeros = tf.zeros_like(toRegularize)
        is_positive = tf.greater(toRegularize, zeros)
        # squared magnitude where positive, zero otherwise
        penalty = tf.where(is_positive, tf_abssqr(toRegularize), zeros)
        return tf.reduce_mean(penalty)
    
    
        
    </script>
    